{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e12d4f35-6bf3-45ac-a68f-f5cf8f3f1694",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from sklearn.svm import LinearSVC\n",
    "from sklearn.metrics import classification_report\n",
    "from sklearn.model_selection import cross_val_predict\n",
    "\n",
    "from gensim.test.utils import common_texts\n",
    "from gensim.models import Word2Vec\n",
    "# Fit Word2Vec on the `common_texts` sample imported from gensim.test.utils.\n",
    "# NOTE: `model` is rebound to a new Word2Vec instance in a later cell.\n",
    "model = Word2Vec(\n",
    "    sentences=common_texts,\n",
    "    vector_size=100,\n",
    "    window=5,\n",
    "    min_count=1,\n",
    "    workers=4,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "87638379-4e14-45a8-9b24-753bef0c99c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import jieba\n",
    "\n",
    "# Base URL under which the intent-classification CSV files are fetched.\n",
    "data_dir = 'https://mirror.coggle.club/dataset/coggle-competition/'\n",
    "\n",
    "# Both files are parsed as tab-separated with no header row, so columns are\n",
    "# addressed by integer position (0 is tokenised below, 1 is used as the label later).\n",
    "read_opts = dict(sep='\\t', header=None)\n",
    "train_data = pd.read_csv(data_dir + 'intent-classify/train.csv', **read_opts)\n",
    "test_data = pd.read_csv(data_dir + 'intent-classify/test.csv', **read_opts)\n",
    "\n",
    "# Replace column 0 of each frame with the token list produced by jieba.lcut.\n",
    "for frame in (train_data, test_data):\n",
    "    frame[0] = frame[0].apply(jieba.lcut)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "7fe6746d-c804-4627-bdbf-ef228b7d0b70",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit the embeddings on the tokenised sentences of both frames (train first, then test).\n",
    "corpus = train_data[0].tolist() + test_data[0].tolist()\n",
    "model = Word2Vec(\n",
    "    sentences=corpus,\n",
    "    vector_size=30,\n",
    "    window=5,\n",
    "    min_count=1,\n",
    "    workers=4,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "942d78a3-15e2-4857-a13e-90945f5d04a4",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('热水器', 0.98893141746521),\n",
       " ('关掉', 0.9859920740127563),\n",
       " ('加湿器', 0.9854778051376343),\n",
       " ('这时候', 0.9815304279327393),\n",
       " ('电脑', 0.9793143272399902),\n",
       " ('，', 0.9790096879005432),\n",
       " ('蒸', 0.975031852722168),\n",
       " ('洗衣机', 0.9734700918197632),\n",
       " ('准备', 0.9727935791015625),\n",
       " ('把', 0.9721974730491638)]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Eyeball the embeddings: the ten tokens closest to '打开' by similarity score.\n",
    "model.wv.most_similar('打开', topn=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "f1761282-5026-407d-80ee-5ae6944d2d90",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sentence_vector(tokens):\n",
    "    \"\"\"Return the mean of the word vectors looked up for `tokens` in `model.wv`.\"\"\"\n",
    "    return model.wv[tokens].mean(axis=0)\n",
    "\n",
    "\n",
    "# One averaged vector per row, stacked into a 2-D array (one row per sentence).\n",
    "train_w2v = np.vstack(train_data[0].apply(sentence_vector))\n",
    "test_w2v = np.vstack(test_data[0].apply(sentence_vector))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "53c2c463-6995-49d6-a6d3-2bb770f33b0c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/sklearn/svm/_classes.py:31: FutureWarning: The default value of `dual` will change from `True` to `'auto'` in 1.5. Set the value of `dual` explicitly to suppress the warning.\n",
      "  warnings.warn(\n",
      "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/sklearn/svm/_base.py:1237: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
      "  warnings.warn(\n",
      "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/sklearn/svm/_classes.py:31: FutureWarning: The default value of `dual` will change from `True` to `'auto'` in 1.5. Set the value of `dual` explicitly to suppress the warning.\n",
      "  warnings.warn(\n",
      "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/sklearn/svm/_base.py:1237: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
      "  warnings.warn(\n",
      "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/sklearn/svm/_classes.py:31: FutureWarning: The default value of `dual` will change from `True` to `'auto'` in 1.5. Set the value of `dual` explicitly to suppress the warning.\n",
      "  warnings.warn(\n",
      "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/sklearn/svm/_base.py:1237: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
      "  warnings.warn(\n",
      "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/sklearn/svm/_classes.py:31: FutureWarning: The default value of `dual` will change from `True` to `'auto'` in 1.5. Set the value of `dual` explicitly to suppress the warning.\n",
      "  warnings.warn(\n",
      "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/sklearn/svm/_base.py:1237: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
      "  warnings.warn(\n",
      "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/sklearn/svm/_classes.py:31: FutureWarning: The default value of `dual` will change from `True` to `'auto'` in 1.5. Set the value of `dual` explicitly to suppress the warning.\n",
      "  warnings.warn(\n",
      "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/sklearn/svm/_base.py:1237: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                       precision    recall  f1-score   support\n",
      "\n",
      "         Alarm-Update       0.91      0.93      0.92      1264\n",
      "           Audio-Play       0.00      0.00      0.00       226\n",
      "       Calendar-Query       0.93      0.97      0.95      1214\n",
      "        FilmTele-Play       0.56      0.63      0.59      1355\n",
      "HomeAppliance-Control       0.84      0.92      0.88      1215\n",
      "           Music-Play       0.73      0.78      0.75      1304\n",
      "                Other       0.10      0.01      0.02       214\n",
      "         Radio-Listen       0.85      0.84      0.85      1285\n",
      "       TVProgram-Play       0.69      0.05      0.09       240\n",
      "         Travel-Query       0.84      0.92      0.88      1220\n",
      "           Video-Play       0.70      0.73      0.71      1334\n",
      "        Weather-Query       0.82      0.83      0.83      1229\n",
      "\n",
      "             accuracy                           0.79     12100\n",
      "            macro avg       0.66      0.63      0.62     12100\n",
      "         weighted avg       0.76      0.79      0.77     12100\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# The default dual coordinate-descent solver did not converge on this data\n",
    "# (liblinear ConvergenceWarning in every fold). There are far more samples than\n",
    "# features (30-d sentence vectors), so solve the primal problem instead; setting\n",
    "# `dual` explicitly also pins behaviour across the scikit-learn 1.5 default change.\n",
    "cv_pred = cross_val_predict(\n",
    "    LinearSVC(dual=False),\n",
    "    train_w2v, train_data[1],\n",
    "    cv=5,  # same as the library default, spelled out for the reader\n",
    ")\n",
    "# cv_pred holds one out-of-fold prediction per training row; compare against the labels in column 1.\n",
    "print(classification_report(train_data[1], cv_pred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e1787ac-fdc5-461f-92bc-d6361cab525d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py3.11",
   "language": "python",
   "name": "py3.11"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
